# Abort the whole installation as soon as any command fails.
set -o errexit

# Base dependencies listed by the project.
pip install --requirement requirements.txt

# Extra / pinned libraries. The order matters: flash-attn is built with
# --no-build-isolation, so it compiles against the torch installed just above
# rather than inside an isolated build environment.
pip install wheel
pip install --upgrade torch torchvision
pip install 'transformers==4.47.1'
pip install 'accelerate==1.2.1'
pip install --no-build-isolation flash-attn
pip install fire

# install this repo
# A pinned CUTLASS checkout is placed under third-party/ before the editable
# install (presumably the build of this repo looks for it there).
# `git clone` refuses to write into an existing non-empty directory, so clone
# only when the checkout is missing; this lets the script be re-run after a
# partial failure instead of aborting here under `set -e`.
if [ ! -e third-party/cutlass/.git ]; then
  git clone https://github.com/NVIDIA/cutlass.git third-party/cutlass
fi
# Pin CUTLASS to the exact commit used for the main repo.
git -C third-party/cutlass checkout ffa34e70756b0bc744e1dfcc115b5a991a68f132
# Editable install of the repository root (the working directory never changes).
pip install -e .

# install int8 kernels
# Runs in a subshell so the caller's working directory is restored
# automatically, even if a step fails part-way through.
(
  cd gemm-int8 || exit 1
  # Clone CUTLASS only when it is not already present: a plain `git clone`
  # into an existing checkout fails and would abort a re-run of this script.
  if [ ! -e cutlass/.git ]; then
    git clone https://github.com/NVIDIA/cutlass.git
  fi
  # Pin CUTLASS to the commit used for the int8 kernels.
  git -C cutlass checkout 902dff366310fe3b0279c1149d5ea6bb7ea0b715
  pip install -e .
)

# install fp8 kernels
# Runs in a subshell so the caller's working directory is restored
# automatically, even if a step fails part-way through.
(
  cd gemm-fp8 || exit 1
  # Clone CUTLASS only when it is not already present: a plain `git clone`
  # into an existing checkout fails and would abort a re-run of this script.
  if [ ! -e cutlass/.git ]; then
    git clone https://github.com/NVIDIA/cutlass.git
  fi
  # Pin CUTLASS to the commit used for the fp8 kernels.
  git -C cutlass checkout 902dff366310fe3b0279c1149d5ea6bb7ea0b715
  pip install -e .
)

# (optional) install baseline kernels (JetFire)
# Runs by default, exactly as before; set INSTALL_JETFIRE=0 in the environment
# to skip this step, e.g. `INSTALL_JETFIRE=0 bash <this script>`.
if [ "${INSTALL_JETFIRE:-1}" != "0" ]; then
  # Clone only when missing so re-running the script does not fail on an
  # already-existing checkout.
  if [ ! -e Jetfire-INT8Training/.git ]; then
    git clone https://github.com/thu-ml/Jetfire-INT8Training.git
  fi
  # Build inside a subshell so the working directory is restored afterwards.
  (
    cd Jetfire-INT8Training/JetfireGEMMKernel || exit 1
    python setup.py install
  )
fi

# (optional) install HALO peft -- disabled by default; uncomment the next line to enable it
# cd peft && pip install -e . && cd ..